% Setup: workspace reset, analysis constants, subject list and location grids.
% "clear" (not "clear all") is enough to reset the workspace; "clear all" also
% flushes cached functions/MEX files and slows the next run.
clear; close all; clc
fs = 44.1e3;   % sampling rate (Hz) assumed for the HRIRs; used to map FFT bins to frequency
f1 = 200;      % lower edge (Hz) of the interaural phase band
f2 = 1000;     % upper edge of the phase band and lower edge of the magnitude band (Hz)
f3 = 12000;    % upper edge (Hz) of the magnitude band

% Location subsets analysed (indices into the HRIR azimuth/elevation grids).
% The main loop runs all three subsets for every cue condition:
%   Figure 1 (all elevations):     xv = 2:2:24, yv = 1:5:50
%   Figure 2 (frontal hemisphere): xv = 2:2:24, yv = 1:5:25
%   Figure 3 (rear hemisphere):    xv = 2:2:24, yv = 25:5:50
folder = 'standard_hrir_database/'; % database root; expected to contain subject_<ID>/hrir_final.mat
SN = {'003';'008';'009';'010';'011';'012';'015';'017';'018';'019';'020';'021';'027';...
    '028';'033';'040';'044';'048';'051';'058';'059';'060';'061';'065';'119';'124';...
    '126';'127';'131';'133';'134';'135';'137';'147';'148';'152';'153';'154';'155';'156'};

xv = 2:2:24; % azimuth indices used: every other index from 2 to 24, 12 overall

x_loc = [-80:15:-65 -55 -45:5:45 55 65:15:80]; % azimuth grid (25 entries)
x_loc(xv)                                      % display the azimuths selected by xv
y_loc = -45:5.625:230.625;                     % elevation grid (50 entries)
ploc = f_heatmap_subplot(3,3); % panel positions; row k is passed to subplot('position',...)
% Main analysis. Each cue condition One1_Two2 gets its own figure. Each
% region whole0_front1_back2 fills the next three panels of ploc:
%   1) build one cue vector per (subject, location) from the HRIRs,
%   2) remove each subject's mean cue vector (over the region's locations),
%   3) run f_LOOCV on the cues and read back its confusion matrix CM,
%   4) save the per-trial angular errors and plot the horizontal-error map,
%      the vertical-error map and the confusion matrix.
%
% Cue conditions:
%   One1_Two2 = 1 : binaural cue = right-minus-left phase over bins N1:N2
%                   followed by right-minus-left magnitude over bins N2:N3
%   One1_Two2 = 2 : the binaural cue above plus the left-ear magnitude over
%                   bins N1:N3 as a separate (monaural) cue
%   One1_Two2 = 3 : left and right magnitudes over bins N1:N3, concatenated
for One1_Two2 = 1:3
    fprintf('One1_Two2 = %d\n', One1_Two2);
    fig = figure('position',[50 50 700 800]);
    cfig = 0;                  % index of the next panel position in ploc
    nSub = length(SN);

    % Mean, over all cells of C, of each cell's last row. The last row is
    % the row appended for the subject currently being processed.
    lastRowMean = @(C) mean(cell2mat(cellfun(@(v) v(end,:), C(:), 'UniformOutput', false)), 1);

    for whole0_front1_back2 = 0:2
        fprintf('  whole0_front1_back2 = %d\n', whole0_front1_back2);

        % Elevation indices of the region.
        switch whole0_front1_back2
            case 0, yv = 1:5:50;   % all elevations, 10 overall
            case 1, yv = 1:5:25;   % front locations
            case 2, yv = 25:5:50;  % back locations
        end
        nAz = numel(xv);
        nEl = numel(yv);
        NC = nAz*nEl;              % number of locations (classes)

        % Fresh, empty per-location containers. One row per subject is
        % appended to each cell below.
        switch One1_Two2
            case 1
                Binaural_vectors = cell(1, NC);
            case 2
                Binaural_vectors = cell(nAz, nEl);
                Monaural_vectors = cell(nAz, nEl);
            case 3
                Monaural_vectors = cell(1, NC);
        end

        for isub = 1:nSub
            % Load into a struct so the MAT-file cannot overwrite script variables.
            S = load(fullfile(folder, ['subject_' SN{isub}], 'hrir_final.mat'));
            hrir_l = S.hrir_l;
            hrir_r = S.hrir_r;

            % FFT bin k (1-based) lies at (k-1)*fs/Nfft Hz; bin 1 is DC.
            Nfft = size(hrir_l, 3);
            fsig = (0:Nfft-1) ./ Nfft .* fs;
            N1 = find(fsig > f1, 1, 'first');   % first bin above f1
            N2 = find(fsig <= f2, 1, 'last');   % last bin at or below f2
            N3 = find(fsig <= f3, 1, 'last');   % last bin at or below f3

            c = 0;                              % linear offset into the 1-D containers
            for c_row = 1:nAz
                azn = xv(c_row);
                temp_vect = [];
                temp_mon = [];
                temp_LR = [];
                for eln = yv
                    % Row-vector spectra of the left/right HRIRs at this location.
                    Hl = fft(squeeze(hrir_l(azn,eln,:)), Nfft).';
                    Hr = fft(squeeze(hrir_r(azn,eln,:)), Nfft).';
                    yl = abs(Hl);
                    yr = abs(Hr);
                    if One1_Two2 <= 2
                        phrtf = [angle(Hr(N1:N2)) - angle(Hl(N1:N2)), yr(N2:N3) - yl(N2:N3)];
                        temp_vect = [temp_vect; phrtf];
                        if One1_Two2 == 2
                            temp_mon = [temp_mon; yl(N1:N3)]; % left ear only, as in the paper
                        end
                    else
                        temp_LR = [temp_LR; yl(N1:N3) yr(N1:N3)];
                    end
                end % elevation loop

                % Append this subject's vectors to the per-location containers.
                % f_HRTF_wrap is applied to all rows of one azimuth at once
                % (presumably a phase-wrapping correction -- confirm).
                switch One1_Two2
                    case 1
                        temp_vect = f_HRTF_wrap(temp_vect);
                        for m = 1:nEl
                            Binaural_vectors{c+m} = [Binaural_vectors{c+m}; temp_vect(m,:)];
                        end
                    case 2
                        temp_vect = f_HRTF_wrap(temp_vect);
                        for m = 1:nEl
                            Binaural_vectors{c_row,m} = [Binaural_vectors{c_row,m}; temp_vect(m,:)];
                            Monaural_vectors{c_row,m} = [Monaural_vectors{c_row,m}; temp_mon(m,:)];
                        end
                    case 3
                        for m = 1:nEl
                            Monaural_vectors{c+m} = [Monaural_vectors{c+m}; temp_LR(m,:)];
                        end
                end
                c = c + nEl;
            end % azimuth loop

            % Remove this subject's mean cue vector (over all locations of
            % the region) from the rows just appended.
            if One1_Two2 < 3
                mu = lastRowMean(Binaural_vectors);
                for k = 1:numel(Binaural_vectors)
                    Binaural_vectors{k}(end,:) = Binaural_vectors{k}(end,:) - mu;
                end
            end
            if One1_Two2 > 1
                mu = lastRowMean(Monaural_vectors);
                for k = 1:numel(Monaural_vectors)
                    Monaural_vectors{k}(end,:) = Monaural_vectors{k}(end,:) - mu;
                end
            end
        end % subject loop

        % Rebuild HRTF_cues from scratch. Otherwise a cue set left over from an
        % earlier condition (e.g. HRTF_cues{2} from condition 2) would reach f_LOOCV.
        switch One1_Two2
            case 1
                HRTF_cues = {Binaural_vectors};
            case 2
                HRTF_cues = {Binaural_vectors, Monaural_vectors};
            case 3
                HRTF_cues = {Monaural_vectors};
        end

        out1 = f_LOOCV(One1_Two2, HRTF_cues, nSub);
        CM = out1.CM;
        nLoc = length(CM);

        % Angular position of every location index, as reported by f_get_xandy.
        xdeg = zeros(nLoc, 1);
        ydeg = zeros(nLoc, 1);
        for k = 1:nLoc
            out = f_get_xandy(k, xv, yv);
            xdeg(k) = out.xdegree;
            ydeg(k) = out.ydegree;
        end
        % Pairwise absolute differences (true i vs classified j), wrapped to <= 180.
        Hdist = abs(xdeg - xdeg.');
        Hdist(Hdist > 180) = 360 - Hdist(Hdist > 180);
        Vdist = abs(ydeg - ydeg.');
        Vdist(Vdist > 180) = 360 - Vdist(Vdist > 180);

        % One error sample per classification trial. Order is true location
        % outer, classified location inner. Each pair is repeated as many
        % times as a "for k = 1:CM(i,j)" loop would run.
        reps = max(floor(reshape(CM.', 1, [])), 0);
        HdT = Hdist.';
        VdT = Vdist.';
        H_forstd = repelem(HdT(:).', reps);
        V_forstd = repelem(VdT(:).', reps);

        mean_H = mean(H_forstd)
        STD_H = std(H_forstd)
        mean_V = mean(V_forstd)
        STD_V = std(V_forstd)

        save(['Cond_' num2str(One1_Two2) '_WFB_' num2str(whole0_front1_back2)], 'H_forstd', 'V_forstd');

        % +++++++++++++ Plotting now:

        % Mean angular error per true location. Divide by that location's
        % number of trials (row sum of CM), not by the number of locations,
        % so the maps are true means and comparable across regions.
        % Rows without trials get 0.
        trialsPerLoc = max(sum(CM, 2), 1);
        locErr = {sum(Hdist .* CM, 2) ./ trialsPerLoc, ...
                  sum(Vdist .* CM, 2) ./ trialsPerLoc};

        maxEs = [10 90];           % error (same units as f_get_xandy) mapped to black
        panelTitles = {'Horizontal Errors', 'Vertical Errors'};
        for icond = 1:2
            figure(fig)
            cfig = cfig + 1;
            subplot('position', ploc(cfig,:)) % error dots
            maxE = maxEs(icond);
            for i = 1:nLoc
                outxyz = f_3Dconvert(xdeg(i), ydeg(i));
                s = plot3(outxyz.x, outxyz.y, outxyz.z, 'ok');
                hold on;
                % White = no error, black = error >= maxE.
                grey = 1 - min(locErr{icond}(i), maxE) / maxE;
                set(s, 'markerfacecolor', grey.*[1 1 1], 'markersize', 6, 'color', [0.8 0.8 0.8]);
            end
            xlabel('x');
            ylabel('y');
            zlabel('z');
            set(gca, 'xtick', -1:0.5:1, 'ytick', -1:0.5:1, 'ztick', -1:0.5:1)
            grid on;
            title(panelTitles{icond});
            xlim([-1 1])
            ylim([-1 1])
            zlim([-1 1])
            axis square
        end

        % Confusion matrix. The darkest square is the largest count. The
        % grey level is (maxCount - count)/maxCount, so a count of 1 is not
        % drawn pure white (invisible) when maxCount is small.
        figure(fig)
        cfig = cfig + 1;
        subplot('position', ploc(cfig,:))
        maxCount = max(CM(:));
        for i = 1:nLoc
            for j = 1:nLoc
                if CM(i,j) > 0
                    s = plot(j, i, 's');
                    hold on;
                    set(s, 'color', 'none', 'markersize', 4, ...
                        'markerfacecolor', (maxCount - CM(i,j)) ./ maxCount .* [1 1 1])
                end
            end
        end
        set(gca, 'xtick', 0:10:nLoc, 'ytick', 0:10:nLoc)
        xlabel('Classified Speaker Index');
        ylabel('True Speaker Index');
        title('Confusion Matrix')
        axis square
        grid on
        clear Binaural_vectors Monaural_vectors
    end
    % Export the figure built in this iteration. ContentType only matters for
    % vector formats, so it is not passed for a PNG.
    exportgraphics(fig, ['Figure_345_Condition' num2str(One1_Two2) '.png'], 'Resolution', 300)
end